import torch
from torch.autograd import Function

# After building, the C++ extension is attached under the package name;
# the module name matches the `name` field of the ext-modules entry.
from my_reduce import _C  # noqa: F401


class MyReduceFunc(Function):
    """Autograd bridge for the compiled `_C` reduce kernels.

    The forward pass delegates to ``_C.forward`` and stashes the input
    tensor so the backward pass can hand it to ``_C.backward``.
    """

    @staticmethod
    def forward(ctx, y):
        # Keep the input around; _C.backward needs it to compute the gradient.
        ctx.save_for_backward(y)
        result = _C.forward(y)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        # Single saved tensor: the original forward input.
        saved = ctx.saved_tensors[0]
        return _C.backward(grad_out, saved)


class MyReduce(torch.nn.Module):
    """nn.Module wrapper so the custom reduce op composes with other layers."""

    def forward(self, y):
        # Route through the autograd Function so gradients flow to _C.backward.
        out = MyReduceFunc.apply(y)
        return out


# Top-level public API exports
__all__ = ["MyReduce", "MyReduceFunc"]
